-
Notifications
You must be signed in to change notification settings - Fork 78
Implement Feature for Issue #379: MMD Hypothesis Test #384
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…ckage - Create function signatures based on @vpratz's clarification in related issue bayesflow-org#379
…: maximum_mean_discrepency takes bf.types.Tensor and return bf.types.Tensor, but at the moment np.ndarray are provided and float return expected
…summary_network takes and return bf.types.Tensor, at the moment np.ndarray is assumed
|
Thanks a lot for the PR!
Keras offers two functions for that,
Yes, I think following the existing functions in style and signature would be good here.
Good points. I missed the naming collision, so maybe a prefix like I think we would want to follow the pattern that we provide the functions by importing them in @stefanradev93 @paul-buerkner As you are more involved in the diagnostics interface, could one of you please comment what you would prefer regarding naming/structure? |
|
Thanks for the input. The proposed changes have been pushed. I left listing of Functions and Dependencies in the module docstring for the devs. |
…om_summaries into diagnostics.metrics
- unit test output shape of compute_mmd_hypothesis_test_from_summaries
…ximator to be ContinuousApproximator due to assumption of attribute summary_network existing
…st on simple distributions like uniform and normal
…dimensions except the first one
…erence data + add corresponding test cases
|
Thank you for working on this PR! The naming overlap is indeed not ideal. Perhaps we can actually rename the plot function. From having a quick look at the code (this is one of the few diagnostics I didn't edit myself yet), it seems as if the functionality is much much general that just for MMD. Since it takes samples from one distribution and compares it to a single empirical value, it could in theory be any test statistic not just MMD. @stefanradev93 can you confirm? If true, I think we should rename the plot function to something more general and then not have the @vpratz What do you think about this suggestion? |
|
Thank you for the feedback! Technically, the functions implemented in this PR do not perform a hypothesis test themselves—they only compute the MMD values. Including "hypothesis_test" in the name might be misleading. Do you have any thoughts on alternative naming? @paul-buerkner I agree that renaming the plot function makes sense. This would free up Would you like to create an issue for tracking this, or should I go ahead and do it? |
…t on backend agnotics Tensors instead of numpy arrays
|
This was accidentally closed. We will investigate how to restore the branch and reopen PRs. |
|
Thanks for the changes! As we might change the observed data in the approximator's adapter, I'd prefer if we pass the data to |
|
@stefanradev93 @LarsKue Do we want to create a "standard" way to obtain the summary outputs from a |
This function enables easy access to the summary space. Naming can still be discussed, as well as better integration/reuse in other functions of the approximator.
- remove somewhat redundant mmd_comparison_from_summaries function - rename mmd_comparison to the more general summary_space_comparison, with configurable distance function (default MMD) - only allow calling summary_space_comparison when we can obtain the summary variables directly from the approximator. For all other use cases, directly refer to bootstrap_comparison - update tests to reflect those changes - remove redundant docstrings from the module
|
I have refactored the functions, please take a look at the individual commit descriptions for details. From my side, open questions are mainly naming-related, feel free to suggest ideas:
Tagging @paul-buerkner @stefanradev93 @LarsKue for those questions. Are you happy with those changes, or do you see room for improvement? @thegialeo |
|
Thanks for the updates — the changes look good to me!
|
|
Thanks for the feedback and good spot with the test, I have added the missing test case. |
Summary variables are the variables that get summarized. I would call the summarized variables |
- add summaries function to ModelCommparisonApproximator as well - add tests for the approximator.summaries functions
|
Thanks for the comment. I have renamed the function to |
Description
This PR introduces the first draft implementation of the MMD hypothesis test feature based on the discussion in issue #379.
Note: The implementation is not complete and should not be merged at this stage.
The key open questions and considerations are outlined below:
Related Issue
Type of Change
🚀 New feature (non-breaking change which adds functionality, no existing code was changed)
Open Questions
Clarification on
bayesflow.types.TensorConversionsnp.ndarraytobayesflow.types.Tensorwhile staying backend agnostic?bayesflow.types.Tensortonp.ndarray/float, given thatmaximum_mean_discrepancyandapproximator.summary_networkoperate withbayesflow.types.Tensor?Data Type Consistency
observed_data/observed_summariesandreference_data/reference_summariesbe of typebayesflow.types.Tensor, or should they takenp.ndarrayas arguments and cast as needed?diagnostics/metricsgenerally usenp.ndarrayorMapping[str, np.ndarray]as arguments, should we maintain this consistency across the new utility functions?Naming & Import Conventions
diagnostics/plotsalready includesmmd_hypothesis_test.py, which definesmmd_hypothesis_test(). To avoid namespace collisions, should we rename either the plot function or the metric function?from package.module import function/classoverimport package.moduleand accessingfunction/classthrough the importedpackage.module, so naming both functions the same could be problematic at some point in the future.Unit Tests
tests/test_diagnostics/test_diagnostics_metrics.pythe correct location for new unit tests for the implemented functions?